{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AttachExternModules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import tvm\n",
    "import tvm.testing\n",
    "from tvm import relax\n",
    "from tvm.relax.frontend import nn\n",
    "from tvm.relax.frontend.nn import spec\n",
    "from tvm.relax.transform import AttachExternModules\n",
    "\n",
    "def _compile_cc(src: Path, dst: Path):\n",
    "    \"\"\"Compile a C++ source file into an object file with ``g++``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    src : Path\n",
    "        C++ source file handed to the compiler.\n",
    "    dst : Path\n",
    "        Object file written through ``-o``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    RuntimeError\n",
    "        If ``g++`` exits with a non-zero status; the message carries the\n",
    "        captured compiler output and the full command line.\n",
    "    \"\"\"\n",
    "    # pylint: disable=import-outside-toplevel\n",
    "    from tvm._ffi.base import py_str\n",
    "    from tvm._ffi.libinfo import find_include_path\n",
    "\n",
    "    # pylint: enable=import-outside-toplevel\n",
    "\n",
    "    include_flags = []\n",
    "    for header_dir in find_include_path():\n",
    "        include_flags += [\"-I\", header_dir]\n",
    "    command = [\n",
    "        \"g++\",\n",
    "        str(src),\n",
    "        *include_flags,\n",
    "        \"-DDMLC_USE_FOPEN64=0\",\n",
    "        \"-DDMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>\",\n",
    "        \"-c\",\n",
    "        \"-fPIC\",  # position-independent code, so the object can be dynamically linked later\n",
    "        \"-o\",\n",
    "        str(dst),\n",
    "    ]\n",
    "    # stderr is folded into stdout so a failure report contains everything g++ printed.\n",
    "    result = subprocess.run(\n",
    "        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, check=False\n",
    "    )\n",
    "    if result.returncode == 0:\n",
    "        return\n",
    "    report = \"\".join(\n",
    "        [\n",
    "            \"Compilation error:\\n\",\n",
    "            py_str(result.stdout),\n",
    "            \"\\nCommand line: \",\n",
    "            \" \".join(command),\n",
    "        ]\n",
    "    )\n",
    "    raise RuntimeError(report)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义推断函数：检查外部函数输入张量的维数与数据类型，并返回描述输出形状与数据类型的占位张量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _infer_scalar_add(x, y):  # pylint: disable=invalid-name\n",
    "    \"\"\"Validate two rank-0 float32 ``nn.Tensor`` inputs and return a rank-0 float32 placeholder.\"\"\"\n",
    "    operands = (x, y)\n",
    "    # Type checks run for both operands before any attribute is inspected.\n",
    "    assert all(isinstance(operand, nn.Tensor) for operand in operands)\n",
    "    assert all(operand.ndim == 0 and operand.dtype == \"float32\" for operand in operands)\n",
    "    return nn.Tensor.placeholder(shape=(), dtype=\"float32\")\n",
    "\n",
    "\n",
    "def _infer_test_sym(a, b):  # pylint: disable=invalid-name\n",
    "    \"\"\"Check ``a`` is ``[x, y, 1]`` and ``b`` is ``[y, z, 5]``; return an ``[x, y, z, 9]`` placeholder.\n",
    "\n",
    "    Both inputs must be rank-3 float32 ``nn.Tensor`` objects, and the second\n",
    "    dimension of ``a`` must be structurally equal to the first dimension of ``b``.\n",
    "    \"\"\"\n",
    "\n",
    "    def _same_dim(lhs, rhs):\n",
    "        return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True)\n",
    "\n",
    "    for operand in (a, b):\n",
    "        assert isinstance(operand, nn.Tensor)\n",
    "    for operand in (a, b):\n",
    "        assert operand.ndim == 3 and operand.dtype == \"float32\"\n",
    "    # The symbolic extents are read from the inputs themselves.\n",
    "    x, y, z = a.shape[0], b.shape[0], b.shape[1]  # pylint: disable=invalid-name\n",
    "    a_dim0, a_dim1, a_dim2 = a.shape  # expected layout: [x, y, 1]\n",
    "    assert _same_dim(a_dim0, x)\n",
    "    assert _same_dim(a_dim1, y)\n",
    "    assert a_dim2 == 1\n",
    "    b_dim0, b_dim1, b_dim2 = b.shape  # expected layout: [y, z, 5]\n",
    "    assert _same_dim(b_dim0, y)\n",
    "    assert _same_dim(b_dim1, z)\n",
    "    assert b_dim2 == 5\n",
    "    return nn.Tensor.placeholder(shape=(x, y, z, 9), dtype=\"float32\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _test_scalar_add(func):\n",
    "    \"\"\"Call ``func`` with the float32 scalars 1.0 and 3.0 and expect the float32 scalar 4.0.\"\"\"\n",
    "    lhs, rhs = (tvm.nd.array(np.array(value).astype(\"float32\")) for value in (1.0, 3.0))\n",
    "    total = func(lhs, rhs).numpy()\n",
    "    assert total.ndim == 0\n",
    "    assert total.dtype == \"float32\"\n",
    "    assert float(total) == 4.0\n",
    "\n",
    "\n",
    "def _test_infer_sym(func, x, y, z):  # pylint: disable=invalid-name\n",
    "    \"\"\"Feed ``func`` random float32 inputs shaped (x, y, 1) and (y, z, 5); expect an (x, y, z, 9) result.\"\"\"\n",
    "    lhs_shape, rhs_shape = (x, y, 1), (y, z, 5)\n",
    "    lhs = tvm.nd.array(np.random.uniform(size=lhs_shape).astype(\"float32\"))\n",
    "    rhs = tvm.nd.array(np.random.uniform(size=rhs_shape).astype(\"float32\"))\n",
    "    result = func(lhs, rhs).numpy()\n",
    "    # Only the output shape is checked; the values are not inspected.\n",
    "    assert result.shape == (x, y, z, 9)\n",
    "\n",
    "\n",
    "def _check_ir_equality(mod):\n",
    "    \"\"\"Assert that ``mod`` is structurally equal to the expected Relax module.\n",
    "\n",
    "    The expected module holds two functions, ``scalar_add`` and ``test_sym``; each\n",
    "    forwards its two inputs to an extern symbol (``ext_scalar_add`` /\n",
    "    ``ext_test_sym``) through ``R.call_dps_packed`` and returns the result.\n",
    "    \"\"\"\n",
    "    # pylint: disable=import-outside-toplevel\n",
    "    from tvm.script import ir as I\n",
    "    from tvm.script import relax as R\n",
    "    from tvm.script import tir as T\n",
    "\n",
    "    # pylint: enable=import-outside-toplevel\n",
    "\n",
    "    @I.ir_module\n",
    "    class ExpectedModule:\n",
    "        @R.function\n",
    "        def scalar_add(\n",
    "            a: R.Tensor((), dtype=\"float32\"), b: R.Tensor((), dtype=\"float32\")\n",
    "        ) -> R.Tensor((), dtype=\"float32\"):\n",
    "            R.func_attr({\"num_input\": 2})\n",
    "            with R.dataflow():\n",
    "                ext_scalar_add = R.call_dps_packed(\n",
    "                    \"ext_scalar_add\", (a, b), out_sinfo=R.Tensor((), dtype=\"float32\")\n",
    "                )\n",
    "                gv: R.Tensor((), dtype=\"float32\") = ext_scalar_add\n",
    "                R.output(gv)\n",
    "            return gv\n",
    "\n",
    "        @R.function\n",
    "        def test_sym(\n",
    "            a: R.Tensor((\"x\", \"y\", 1), dtype=\"float32\"), b: R.Tensor((\"y\", \"z\", 5), dtype=\"float32\")\n",
    "        ) -> R.Tensor((\"x\", \"y\", \"z\", 9), dtype=\"float32\"):\n",
    "            x = T.int64()\n",
    "            y = T.int64()\n",
    "            z = T.int64()\n",
    "            R.func_attr({\"num_input\": 2})\n",
    "            with R.dataflow():\n",
    "                ext_test_sym = R.call_dps_packed(\n",
    "                    \"ext_test_sym\", (a, b), out_sinfo=R.Tensor((x, y, z, 9), dtype=\"float32\")\n",
    "                )\n",
    "                gv1: R.Tensor((x, y, z, 9), dtype=\"float32\") = ext_test_sym\n",
    "                R.output(gv1)\n",
    "            return gv1\n",
    "\n",
    "    # Raises if the two modules differ structurally.\n",
    "    tvm.ir.assert_structural_equal(ExpectedModule, mod)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试外部对象"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tempfile.TemporaryDirectory() as temp_dir_str:\n",
    "    # The object file is written into a temporary directory that is removed when\n",
    "    # this block exits, so export, compilation and execution all happen inside it.\n",
    "    path = Path(temp_dir_str) / \"main.o\"\n",
    "    _compile_cc(\n",
    "        # Relative path: resolved against the current working directory when the cell runs.\n",
    "        src=Path(\"frontend_nn_extern_module.cc\"),\n",
    "        dst=path,\n",
    "    )\n",
    "\n",
    "    class TestModule(nn.Module):\n",
    "        \"\"\"Module whose methods dispatch to extern functions of a precompiled object file.\"\"\"\n",
    "\n",
    "        def __init__(self):\n",
    "            # Created lazily by ``_get_ext_mod``.\n",
    "            self.ext_mod = None\n",
    "\n",
    "        def _get_ext_mod(self):\n",
    "            \"\"\"Build the ``nn.ObjectModule`` on first use, register it with ``nn.add_extern``, and cache it.\"\"\"\n",
    "            if self.ext_mod is not None:\n",
    "                return self.ext_mod\n",
    "            # Extern symbol name -> function that infers the output tensor.\n",
    "            symbol_to_infer = {\n",
    "                \"ext_scalar_add\": _infer_scalar_add,\n",
    "                \"ext_test_sym\": _infer_test_sym,\n",
    "            }\n",
    "            self.ext_mod = nn.ObjectModule(symbol_to_infer, path)\n",
    "            nn.add_extern(self.ext_mod)\n",
    "            return self.ext_mod\n",
    "\n",
    "        def scalar_add(self, a: nn.Tensor, b: nn.Tensor):  # pylint: disable=invalid-name\n",
    "            \"\"\"Call the extern ``ext_scalar_add`` on ``a`` and ``b``.\"\"\"\n",
    "            return self._get_ext_mod()[\"ext_scalar_add\"](a, b)\n",
    "\n",
    "        def test_sym(self, a: nn.Tensor, b: nn.Tensor):  # pylint: disable=invalid-name\n",
    "            \"\"\"Call the extern ``ext_test_sym`` on ``a`` and ``b``.\"\"\"\n",
    "            return self._get_ext_mod()[\"ext_test_sym\"](a, b)\n",
    "\n",
    "    # Export both methods; the extern modules come back as the third element.\n",
    "    mod, _, ext_mods = TestModule().export_tvm(\n",
    "        spec={\n",
    "            \"scalar_add\": {\n",
    "                \"a\": spec.Tensor((), \"float32\"),\n",
    "                \"b\": spec.Tensor((), \"float32\"),\n",
    "            },\n",
    "            \"test_sym\": {\n",
    "                \"a\": spec.Tensor((\"x\", \"y\", 1), \"float32\"),\n",
    "                \"b\": spec.Tensor((\"y\", \"z\", 5), \"float32\"),\n",
    "            },\n",
    "        },\n",
    "        allow_extern=True,\n",
    "    )\n",
    "    # The IR is compared before the extern modules are attached.\n",
    "    _check_ir_equality(mod)\n",
    "    mod = AttachExternModules(ext_mods)(mod)  # pylint: disable=not-callable\n",
    "    compiled = tvm.runtime.relax_vm.VirtualMachine(\n",
    "        tvm.compile(mod, target=\"llvm\"),\n",
    "        device=tvm.cpu(),\n",
    "    )\n",
    "    _test_scalar_add(compiled[\"scalar_add\"])\n",
    "    _test_infer_sym(compiled[\"test_sym\"], x=3, y=4, z=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试外部源码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative path: resolved against the current working directory when the cell runs.\n",
    "source = Path(\"frontend_nn_extern_module.cc\")\n",
    "\n",
    "class TestModule(nn.Module):\n",
    "    \"\"\"Module whose methods dispatch to extern functions supplied as C++ source.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Created lazily by ``_get_ext_mod``.\n",
    "        self.ext_mod = None\n",
    "\n",
    "    def _get_ext_mod(self):\n",
    "        \"\"\"Build the ``nn.SourceModule`` on first use, register it with ``nn.add_extern``, and cache it.\"\"\"\n",
    "        if self.ext_mod is not None:\n",
    "            return self.ext_mod\n",
    "        # Extern symbol name -> function that infers the output tensor.\n",
    "        symbol_to_infer = {\n",
    "            \"ext_scalar_add\": _infer_scalar_add,\n",
    "            \"ext_test_sym\": _infer_test_sym,\n",
    "        }\n",
    "        self.ext_mod = nn.SourceModule(\n",
    "            symbol_to_infer,\n",
    "            source_code=source,\n",
    "            source_format=\"cpp\",\n",
    "        )\n",
    "        nn.add_extern(self.ext_mod)\n",
    "        return self.ext_mod\n",
    "\n",
    "    def scalar_add(self, a: nn.Tensor, b: nn.Tensor):  # pylint: disable=invalid-name\n",
    "        \"\"\"Call the extern ``ext_scalar_add`` on ``a`` and ``b``.\"\"\"\n",
    "        return self._get_ext_mod()[\"ext_scalar_add\"](a, b)\n",
    "\n",
    "    def test_sym(self, a: nn.Tensor, b: nn.Tensor):  # pylint: disable=invalid-name\n",
    "        \"\"\"Call the extern ``ext_test_sym`` on ``a`` and ``b``.\"\"\"\n",
    "        return self._get_ext_mod()[\"ext_test_sym\"](a, b)\n",
    "\n",
    "# Export both methods; the extern modules come back as the third element.\n",
    "mod, _, ext_mods = TestModule().export_tvm(\n",
    "    spec={\n",
    "        \"scalar_add\": {\n",
    "            \"a\": spec.Tensor((), \"float32\"),\n",
    "            \"b\": spec.Tensor((), \"float32\"),\n",
    "        },\n",
    "        \"test_sym\": {\n",
    "            \"a\": spec.Tensor((\"x\", \"y\", 1), \"float32\"),\n",
    "            \"b\": spec.Tensor((\"y\", \"z\", 5), \"float32\"),\n",
    "        },\n",
    "    },\n",
    "    allow_extern=True,\n",
    ")\n",
    "# The IR is compared before the extern modules are attached.\n",
    "_check_ir_equality(mod)\n",
    "mod = AttachExternModules(ext_mods)(mod)  # pylint: disable=not-callable\n",
    "compiled = tvm.runtime.relax_vm.VirtualMachine(\n",
    "    tvm.compile(mod, target=\"llvm\"),\n",
    "    device=tvm.cpu(),\n",
    ")\n",
    "_test_scalar_add(compiled[\"scalar_add\"])\n",
    "_test_infer_sym(compiled[\"test_sym\"], x=3, y=4, z=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
